{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License. \n",
    "\n",
    "# Datalist Cross-Validation Folds Generator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook contains an example to add cross-validation folds to an existing Medical Segmentation Decathlon datalist, in this case the one of Task09_Spleen. \n",
    "When running repeated experiments, it can be beneficial to create cross-validation folds beforehand."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "!python3 -c \"import monai\" || pip install -q \"monai\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MONAI version: 1.2.0rc4+5.gfe550a12\n",
      "Numpy version: 1.22.2\n",
      "Pytorch version: 1.14.0a0+410ce96\n",
      "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n",
      "MONAI rev id: fe550a12a4b2205d7c374f0436bfa0081505f668\n",
      "MONAI __file__: /workspace/monai/monai-in-dev/monai/__init__.py\n",
      "\n",
      "Optional dependencies:\n",
      "Pytorch Ignite version: 0.4.11\n",
      "ITK version: 5.3.0\n",
      "Nibabel version: 5.0.1\n",
      "scikit-image version: 0.20.0\n",
      "Pillow version: 9.2.0\n",
      "Tensorboard version: 2.9.0\n",
      "gdown version: 4.6.4\n",
      "TorchVision version: 0.15.0a0\n",
      "tqdm version: 4.64.1\n",
      "lmdb version: 1.4.0\n",
      "psutil version: 5.9.4\n",
      "pandas version: 1.5.2\n",
      "einops version: 0.6.0\n",
      "transformers version: 4.21.3\n",
      "mlflow version: 2.2.2\n",
      "pynrrd version: 1.0.0\n",
      "\n",
      "For details about installing the optional dependencies, please visit:\n",
      "    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "import random\n",
    "import shutil\n",
    "import tempfile\n",
    "from monai.config import print_config\n",
    "from monai.apps import download_and_extract\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Setup paths to your data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/data\n"
     ]
    }
   ],
   "source": [
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "if directory is not None:\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "print(root_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Download sample MSD Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "msd_task = \"Task09_Spleen\"\n",
    "resource = \"https://msd-for-monai.s3-us-west-2.amazonaws.com/\" + msd_task + \".tar\"\n",
    "\n",
    "compressed_file = os.path.join(root_dir, msd_task + \".tar\")\n",
    "dataroot = os.path.join(root_dir, msd_task)\n",
    "\n",
    "if not os.path.exists(dataroot):\n",
    "    download_and_extract(resource, compressed_file, root_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# MSD dataset structure follows the following convention:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "test_dir = os.path.join(dataroot, \"imagesTs/\")\n",
    "train_dir = os.path.join(dataroot, \"imagesTr/\")\n",
    "label_dir = os.path.join(dataroot, \"labelsTr/\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Construct skeleton JSON to populate with your own data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "datalist_json = {\"testing\": [], \"training\": []}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Populate JSON with test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "datalist_json[\"testing\"] = [\n",
    "    {\"image\": \"./imagesTs/\" + file} for file in os.listdir(test_dir) if (\".nii.gz\" in file) and (\"._\" not in file)\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Visualise testing data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'image': './imagesTs/spleen_54.nii.gz'},\n",
       " {'image': './imagesTs/spleen_42.nii.gz'},\n",
       " {'image': './imagesTs/spleen_7.nii.gz'},\n",
       " {'image': './imagesTs/spleen_39.nii.gz'},\n",
       " {'image': './imagesTs/spleen_30.nii.gz'},\n",
       " {'image': './imagesTs/spleen_43.nii.gz'},\n",
       " {'image': './imagesTs/spleen_1.nii.gz'},\n",
       " {'image': './imagesTs/spleen_51.nii.gz'},\n",
       " {'image': './imagesTs/spleen_34.nii.gz'},\n",
       " {'image': './imagesTs/spleen_11.nii.gz'}]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datalist_json[\"testing\"][:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Populate with training images and labels in your directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "datalist_json[\"training\"] = [\n",
    "    {\"image\": \"./imagesTr/\" + file, \"label\": \"./labelsTr/\" + file, \"fold\": 0}\n",
    "    for file in os.listdir(train_dir)\n",
    "    if (\".nii.gz\" in file) and (\"._\" not in file)\n",
    "]  # Initialize as single fold"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Visualise training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'image': './imagesTr/.lock', 'label': './labelsTr/.lock', 'fold': 0},\n",
       " {'image': './imagesTr/spleen_17.nii.gz',\n",
       "  'label': './labelsTr/spleen_17.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_62.nii.gz',\n",
       "  'label': './labelsTr/spleen_62.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_52.nii.gz',\n",
       "  'label': './labelsTr/spleen_52.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_10.nii.gz',\n",
       "  'label': './labelsTr/spleen_10.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/datastore_v2.json',\n",
       "  'label': './labelsTr/datastore_v2.json',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/labels', 'label': './labelsTr/labels', 'fold': 0},\n",
       " {'image': './imagesTr/spleen_29.nii.gz',\n",
       "  'label': './labelsTr/spleen_29.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_28.nii.gz',\n",
       "  'label': './labelsTr/spleen_28.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_14.nii.gz',\n",
       "  'label': './labelsTr/spleen_14.nii.gz',\n",
       "  'fold': 0}]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datalist_json[\"training\"][:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Randomise training data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'image': './imagesTr/spleen_14.nii.gz',\n",
       "  'label': './labelsTr/spleen_14.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_52.nii.gz',\n",
       "  'label': './labelsTr/spleen_52.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_21.nii.gz',\n",
       "  'label': './labelsTr/spleen_21.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_44.nii.gz',\n",
       "  'label': './labelsTr/spleen_44.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_16.nii.gz',\n",
       "  'label': './labelsTr/spleen_16.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_40.nii.gz',\n",
       "  'label': './labelsTr/spleen_40.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_9.nii.gz',\n",
       "  'label': './labelsTr/spleen_9.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_41.nii.gz',\n",
       "  'label': './labelsTr/spleen_41.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_10.nii.gz',\n",
       "  'label': './labelsTr/spleen_10.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_32.nii.gz',\n",
       "  'label': './labelsTr/spleen_32.nii.gz',\n",
       "  'fold': 0}]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "random.seed(42)\n",
    "random.shuffle(datalist_json[\"training\"])\n",
    "datalist_json[\"training\"][:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Split training data into N random folds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "num_folds = 5\n",
    "fold_size = len(datalist_json[\"training\"]) // num_folds\n",
    "for i in range(num_folds):\n",
    "    for j in range(fold_size):\n",
    "        datalist_json[\"training\"][i * fold_size + j][\"fold\"] = i"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Visualise final training data with all randomised folds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'image': './imagesTr/spleen_14.nii.gz',\n",
       "  'label': './labelsTr/spleen_14.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_52.nii.gz',\n",
       "  'label': './labelsTr/spleen_52.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_21.nii.gz',\n",
       "  'label': './labelsTr/spleen_21.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_44.nii.gz',\n",
       "  'label': './labelsTr/spleen_44.nii.gz',\n",
       "  'fold': 0},\n",
       " {'image': './imagesTr/spleen_16.nii.gz',\n",
       "  'label': './labelsTr/spleen_16.nii.gz',\n",
       "  'fold': 0}]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datalist_json[\"training\"][:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Save JSON to file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Datalist is saved to msd_task09_spleen_folds.json\n"
     ]
    }
   ],
   "source": [
    "datalist_file = \"msd_\" + msd_task.lower() + \"_folds.json\"\n",
    "with open(datalist_file, \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(datalist_json, f, ensure_ascii=False, indent=4)\n",
    "print(f\"Datalist is saved to {datalist_file}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Cleanup temporary files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "if directory is None:\n",
    "    shutil.rmtree(root_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
